% Start from a clean session and load the stored solution file. It is expected
% to supply XX, T, N, the model parameters, Vsq, etc. that the first sweep reads.
clc
clear all
close all
load('solN2_dur.mat')

% Welfare weights ALPHA visited by the first sweep (displayed on purpose).
ALPHall=[0.4861 0.48 0.475 0.473 0.472 0.46 0.455 0.453 0.452 0.451 ...
         0.44 0.435 0.433 0.432 0.418 0.417 0.416 0.415 0.414 0.4 ...
         0.399 0.398 0.385 0.383 0.382 0.37 0.369 0.368 0.356 0.355 ...
         0.347 0.3467 0.3 0.25 0.23 0.225]

% Sweep 1: for each welfare weight ALPHA in ALPHall, solve resid3v with fsolve
% (warm-started from the previous weight) and store the solution in XX(ii,:).
% Then recompute the implied steady state, transition allocations, lifetime
% utilities V, the capital-tax path tauKt and the FOC diagnostics.
% NOTE(review): the scalar uses of lamda/Delta2 below (lamda_all, kappa, ...)
% assume N==2.
for ii=1:length(ALPHall)

    ALPHA=ALPHall(ii)

    % Warm start from the previous weight's solution. For ii==1 there is no
    % previous row (XX(0,:) is an invalid index), so use the first stored row.
    if ii>1
        X=XX(ii-1,:);
    else
        X=XX(1,:);
    end

    tauK0 = taumax;
    tolf0 = 1e-7;

    options = optimset('MaxFunEvals',10000000,'MaxIter',100,'TolFun',1e-15,'TolX',1e-15,'Disp','iter');
    [X,check] = fsolve('resid3v',X,options)
    sum(resid3v(X).^2)

    XX(ii,:)=X;

    % Unpack X = [k (T) | l (T) | gamma (T) | Delta1 | Delta2 (N-1) | lamda (N-1)]
    k = X(1:T);
    l = X(T+1:2*T);
    gamma = X(2*T+1:3*T);
    Delta1 = X(3*T+1);
    Delta2 = X(3*T+2:3*T+2+N-2);
    lamda = X(3*T+3+N-2:3*T+3+2*(N-2));
    Delta1*kistart(1)+sum(Delta2.*kistart(2:N))
    kstst = fzero('findstst',[kstart-0.2 kmax],options)

    rstst = betta^(-1)-1+depr; %r from Euler at the steady state
    estst = (rstst/alpha)^(1/(1-alpha))*kstst; % express total effective labor from r=F_k
    % Both helper vectors must use the multiplier lamda of the current solution.
    % Preallocate so no stale entries from the loaded workspace survive.
    ljpart = zeros(1,N-1);
    L2a = zeros(1,N-1);
    for jj=2:N
        ljpart(jj-1)=popweight(jj)*phi(jj)*lamda(jj-1)^(-sigmaC/sigmaL)*(phi(jj)/phi(1))^(1/sigmaL);
        L2a(jj-1) = lamda(jj-1)^(-sigmaC/sigmaL)*(phi(jj)/phi(1))^(1/sigmaL); %last terms in (11) and f_l
    end
    lstst=estst/(popweight(1)*phi(1)+sum(ljpart)); %express l1 using the expression for total effective labor and l2=f(lam,l1)
    wstst = (1-alpha)*kstst^alpha*estst^(-alpha); %w=F_e
    cstst = 1/(popweight(1)+sum(popweight(2:N).*lamda))*(kstst.^alpha*estst.^(1-alpha) - depr*kstst - govexp); %c from resource constraint
    mustst = omega*lstst^sigmaL*(1+sum(ALPHA.*lamda.^(-sigmaC).*phi(2:N)'/phi(1).*L2a)+(Delta1+sum(Delta2.*phi(2:N)'/phi(1).*L2a))*(1+sigmaL))/(wstst*(popweight(1)*phi(1)+sum(ljpart)));

    klast = [kstart X(1:T-1)]; %k_{t-1}
    e=[l' (l'*(lamda.^(-sigmaC/sigmaL).*(phi(2:N)/phi(1))'.^(1/sigmaL)))]*(popweight'.*phi);
    e=e';

    r = alpha*klast.^(alpha-1).*e.^(1-alpha); %r=F_k
    w = (1-alpha)*klast.^(alpha).*e.^(-alpha); %w=F_e
    c = 1/(popweight(1)+sum(popweight(2:N).*lamda))*(klast.^alpha.*e.^(1-alpha)+ (1-depr)*klast - k - govexp); %resource constraint

    cnext = [c(2:end) cstst];
    rnext = [r(2:end) rstst];
    lnext = [X(T+2:2*T) lstst];

    %check bound on tauk
    constrtest = c.^(-sigmaC)./cnext.^(-sigmaC) - betta*(1+ (rnext-depr)*(1-taumax));

    cistst = [cstst lamda*cstst]; %consumption of groups in the long run
    listst = [lstst, L2(lamda,lstst,phi(2:N))']; %hours worked of groups in the long run
    ci = zeros(T,N);
    li = zeros(T,N);
    for tt=1:T
        ci(tt,:) = [c(tt) lamda*c(tt)]; %consumption of groups during the transition
        li(tt,:) = [l(tt), L2(lamda,l(tt),phi(2:N))']; %hours worked of groups during the transition
    end

    %lifetime utilities
    if (sigmaC==1)
        V = sum(betta.^[0:T-1]'*ones(1,N).*(log(ci)- omega*li.^(1+sigmaL)/(1+sigmaL)))...
        + betta^(T)/(1-betta)*(log(cistst) - omega*listst.^(1+sigmaL)/(1+sigmaL));
    else
        V = sum(betta.^[0:T-1]'*ones(1,N).*((ci.^(1-sigmaC)-1)/(1-sigmaC) - omega*li.^(1+sigmaL)/(1+sigmaL)))...
        + betta^(T)/(1-betta)*((cistst.^(1-sigmaC)-1)/(1-sigmaC) - omega*listst.^(1+sigmaL)/(1+sigmaL));
    end

    % Pareto improvement flag: every group at least as well off as in the status quo.
    if all(V>=Vsq)
        PI(ii)=1
    else
        PI(ii)=0
    end

    residall(ii)=sum(resid3v(X).^2)
    Vall(ii,:)=V
    Vsq

    %tax rates implied by the Euler equation
    tauKt = 1 - ((c(1:end-1)./c(2:end)).^(-sigmaC)/betta-1)./(r(2:end)-depr); %Euler
    tauKt = [tauK0 tauKt];
    plot(tauKt)

    PI(ii)
    tauKt
    [swi dur]=min((tauKt-taumax/2).^2);
    duration(ii)=dur
    lamda_all(ii)=lamda;
    Delta1_all(ii)=Delta1;
    Delta2_all(ii)=Delta2;
    term_all(ii)=1+ALPHA*lamda^(1-sigmaC)+(Delta1+Delta2*lamda)*(1-sigmaC);
    kappa=lamda^(-sigmaC/sigmaL)*(phi(2)/phi(1))^(1/sigmaL);
    terml_all(ii)=1+ALPHA*kappa^(1+sigmaL)+(Delta1+phi(2)/phi(1)*kappa*Delta2)*(1+sigmaL);
    cst_all(ii)=cstst;
    gammast_all(ii)=gamma(T);
    must_all(ii)=mustst;

    %foc for labor
    klam=lamda^(-sigmaC/sigmaL)*(phi(2)/phi(1))^(1/sigmaL);
    lamderl(ii)=-lamda^(-sigmaC/sigmaL)*(phi(2)/phi(1))^((1+sigmaL)/sigmaL)*betta^T*(1+sigmaL)*lstst^(sigmaL)/(sum(betta.^[0:T-1].*c.^(-sigmaC).*c)+betta^T/(1-betta)*cstst^(-sigmaC)*cstst...
        -sigmaC/sigmaL*lamda^(-(sigmaC+sigmaL)/sigmaL)*(phi(2)/phi(1))^((1+sigmaL)/sigmaL)*(sum(betta.^[0:T-1].*(-omega*l.^(sigmaL).*l))+betta^T/(1-betta)*(-omega)*lstst^sigmaL*lstst));
    %foc for consumption
    lamderc(ii)=-lamda*betta^T*(1-sigmaC)*cstst^(-sigmaC)/(sum(betta.^[0:T-1].*c.^(-sigmaC).*c)+betta^T/(1-betta)*cstst^(-sigmaC)*cstst...
        -sigmaC/sigmaL*lamda^(-(sigmaC+sigmaL)/sigmaL)*(phi(2)/phi(1))^((1+sigmaL)/sigmaL)*(sum(betta.^[0:T-1].*(-omega*l.^(sigmaL).*l))+betta^T/(1-betta)*(-omega)*lstst^sigmaL*lstst));

    % Checkpoint the whole workspace after every weight.
    save('solN2PO.mat')

end

% Second sweep: reverse the group order so the roles of the two groups are
% swapped. Initial capital holdings (kistart) and productivities (phi) are
% flipped. Further down, Vall2(:,2) is read as capitalists and Vall2(:,1) as
% workers.
% NOTE(review): popweight is NOT flipped here. Confirm that the population
% weights are symmetric or are reordered elsewhere (e.g. inside resid3v).
kistart = flip(kj);
phi=flip(phi);

% Welfare weights for the second sweep (displayed on purpose).
ALPHall2=[2 1.5 1.2 1 0.7 0.5 0.3 0.2 0.1 0.06 0.02 0]

% Initial guess for the first weight. Start from the first stored solution
% row, then overwrite the hours block and the multipliers.
X=XX(1,:)
lamda=2;
Delta1 = -0.1;
Delta2 = 0.3;
% Hours guess: pass each stored l_t through L2 with the inverted multiplier
% 1/lamda. This presumably maps hours across the swapped groups; verify in L2.
for i=1:T;
   X(T+i) = L2(1/lamda,X(T+i),phi(1)); % l block, X(T+1:2*T)
end;
% Tail of X is [Delta1, Delta2, lamda]. Scalar Delta2 and lamda mean this assumes N==2.
X(3*T+1:3*T+1+2*(N-1)) = [Delta1 Delta2 lamda];

% Sweep 2 (groups in flipped order, see the setup above): for each weight in
% ALPHall2, solve resid3v starting from the previous solution. Convergence is
% helped by hand-tuned restarts, broydn/fsolve alternation and smoothing of the
% capital path. Results are stored with a "2" suffix so sweep 1 is not overwritten.
% NOTE(review): the scalar uses of lamda/Delta2 below assume N==2.

% Loop-invariant solver settings (identical for every weight).
tauK0 = taumax;
tolf0 = 1e-7;
options = optimset('MaxFunEvals',10000000,'MaxIter',100,'TolFun',1e-15,'TolX',1e-15,'Disp','iter');

% Hand-tuned restart guesses [ii, lamda, Delta1] for weights where the warm
% start alone did not converge.
restartGuess = [ 5, 1.8,   -0.2;
                 7, 1.72,  -0.34;
                 8, 1.74,  -0.36;
                 9, 1.73,  -0.38;
                10, 1.735, -0.4;
                11, 1.73,  -0.43;
                12, 1.705, -0.415];

for ii=1:length(ALPHall2)

    ALPHA=ALPHall2(ii)

    if ii>1
        X=XX2(ii-1,:); % warm start from the previous weight's solution
    end
    X(3*T+3+N-2:3*T+3+2*(N-2)) % current lamda guess
    X(3*T+1)                   % current Delta1 guess

    irow = find(restartGuess(:,1)==ii);
    if ~isempty(irow)
        X(3*T+3+N-2:3*T+3+2*(N-2))=restartGuess(irow,2); %lamda
        X(3*T+1)=restartGuess(irow,3); %Delta1
    end

    if ii==4 || ii>=6
        [X,check] = broydn('resid3v',X,tolf0)
        if size(X,1)>1
            X=X'; % broydn may return a column; keep X a row
        end
        sum(resid3v(X).^2)
    end

    [X,check] = fsolve('resid3v',X,options)
    sum(resid3v(X).^2)

    % Smooth the tail of the capital path k = X(1:T) with an in-place
    % 3-point moving average before re-solving.
    if ii==1
        for i=T-5:T
            X(i) = (X(max(i,2)-1)+X(i)+X(min(i,T-1)+1))/3; %k
        end
    end
    if ii==6
        for i=55:T
            X(i) = (X(max(i,2)-1)+X(i)+X(min(i,T-1)+1))/3; %k
        end
    end
    if ii==13 % only reachable if ALPHall2 is extended beyond 12 entries
        for i=60:T
            X(i) = (X(max(i,2)-1)+X(i)+X(min(i,T-1)+1))/3; %k
        end
    end

    if ii==1 || ii>=5
        [X,check] = broydn('resid3v',X,tolf0)
        if size(X,1)>1
            X=X';
        end
        sum(resid3v(X).^2)

        [X,check] = fsolve('resid3v',X,options)
        sum(resid3v(X).^2)
    end

    if ii==8
        for i=T-12:T
            X(i) = (X(max(i,2)-1)+X(i)+X(min(i,T-1)+1))/3; %k
        end
        [X,check] = broydn('resid3v',X,tolf0)
        if size(X,1)>1
            X=X';
        end
        sum(resid3v(X).^2)

        [X,check] = fsolve('resid3v',X,options)
        sum(resid3v(X).^2)
        [X,check] = broydn('resid3v',X,tolf0)
        if size(X,1)>1
            X=X';
        end
        sum(resid3v(X).^2)

        [X,check] = fsolve('resid3v',X,options)
        sum(resid3v(X).^2)
    end

    if ii==15 % only reachable if ALPHall2 is extended to 15+ entries
        [X,check] = fsolve('resid3v',X,options)
        sum(resid3v(X).^2)
    end

    if ii==3 || ii==4
        [X,check] = broydn('resid3v',X,tolf0)
        if size(X,1)>1
            X=X';
        end
        sum(resid3v(X).^2)

        for i=1:T
            X(i) = (X(max(i,2)-1)+X(i)+X(min(i,T-1)+1))/3; %k
        end

        [X,check] = fsolve('resid3v',X,options)
        sum(resid3v(X).^2)

        [X,check] = broydn('resid3v',X,tolf0)
        if size(X,1)>1
            X=X';
        end
        sum(resid3v(X).^2)

        [X,check] = fsolve('resid3v',X,options)
        sum(resid3v(X).^2)
    end

    XX2(ii,:)=X;

    % Unpack X = [k (T) | l (T) | gamma (T) | Delta1 | Delta2 (N-1) | lamda (N-1)]
    k = X(1:T);
    l = X(T+1:2*T);
    gamma = X(2*T+1:3*T);
    Delta1 = X(3*T+1);
    Delta2 = X(3*T+2:3*T+2+N-2);
    lamda = X(3*T+3+N-2:3*T+3+2*(N-2));
    Delta1*kistart(1)+sum(Delta2.*kistart(2:N))
    kstst = fzero('findstst',[kstart-0.2 kmax],options)

    rstst = betta^(-1)-1+depr; %r from Euler at the steady state
    estst = (rstst/alpha)^(1/(1-alpha))*kstst; % express total effective labor from r=F_k
    % Both helper vectors must use the multiplier lamda of the current solution.
    ljpart = zeros(1,N-1);
    L2a = zeros(1,N-1);
    for jj=2:N
        ljpart(jj-1)=popweight(jj)*phi(jj)*lamda(jj-1)^(-sigmaC/sigmaL)*(phi(jj)/phi(1))^(1/sigmaL);
        L2a(jj-1) = lamda(jj-1)^(-sigmaC/sigmaL)*(phi(jj)/phi(1))^(1/sigmaL); %last terms in (11) and f_l
    end
    lstst=estst/(popweight(1)*phi(1)+sum(ljpart)); %express l1 using the expression for total effective labor and l2=f(lam,l1)
    wstst = (1-alpha)*kstst^alpha*estst^(-alpha); %w=F_e
    cstst = 1/(popweight(1)+sum(popweight(2:N).*lamda))*(kstst.^alpha*estst.^(1-alpha) - depr*kstst - govexp); %c from resource constraint
    mustst = omega*lstst^sigmaL*(1+sum(ALPHA.*lamda.^(-sigmaC).*phi(2:N)'/phi(1).*L2a)+(Delta1+sum(Delta2.*phi(2:N)'/phi(1).*L2a))*(1+sigmaL))/(wstst*(popweight(1)*phi(1)+sum(ljpart)));

    klast = [kstart X(1:T-1)]; %k_{t-1}
    e=[l' (l'*(lamda.^(-sigmaC/sigmaL).*(phi(2:N)/phi(1))'.^(1/sigmaL)))]*(popweight'.*phi);
    e=e';

    r = alpha*klast.^(alpha-1).*e.^(1-alpha); %r=F_k
    w = (1-alpha)*klast.^(alpha).*e.^(-alpha); %w=F_e
    c = 1/(popweight(1)+sum(popweight(2:N).*lamda))*(klast.^alpha.*e.^(1-alpha)+ (1-depr)*klast - k - govexp); %resource constraint

    cnext = [c(2:end) cstst];
    rnext = [r(2:end) rstst];
    lnext = [X(T+2:2*T) lstst];

    %check bound on tauk
    constrtest = c.^(-sigmaC)./cnext.^(-sigmaC) - betta*(1+ (rnext-depr)*(1-taumax));

    cistst = [cstst lamda*cstst]; %consumption of groups in the long run
    listst = [lstst, L2(lamda,lstst,phi(2:N))']; %hours worked of groups in the long run
    ci = zeros(T,N);
    li = zeros(T,N);
    for tt=1:T
        ci(tt,:) = [c(tt) lamda*c(tt)]; %consumption of groups during the transition
        li(tt,:) = [l(tt), L2(lamda,l(tt),phi(2:N))']; %hours worked of groups during the transition
    end

    %lifetime utilities
    if (sigmaC==1)
        V = sum(betta.^[0:T-1]'*ones(1,N).*(log(ci)- omega*li.^(1+sigmaL)/(1+sigmaL)))...
        + betta^(T)/(1-betta)*(log(cistst) - omega*listst.^(1+sigmaL)/(1+sigmaL));
    else
        V = sum(betta.^[0:T-1]'*ones(1,N).*((ci.^(1-sigmaC)-1)/(1-sigmaC) - omega*li.^(1+sigmaL)/(1+sigmaL)))...
        + betta^(T)/(1-betta)*((cistst.^(1-sigmaC)-1)/(1-sigmaC) - omega*listst.^(1+sigmaL)/(1+sigmaL));
    end

    % Pareto improvement flag. V is in the flipped group order, so compare it
    % with the status-quo utilities in the same (flipped) order. Stored in PI2
    % so that sweep 1's PI is kept.
    if all(V>=flip(Vsq))
        PI2(ii)=1
    else
        PI2(ii)=0
    end

    residall2(ii)=sum(resid3v(X).^2)
    Vall2(ii,:)=V
    Vsq

    %tax rates implied by the Euler equation
    tauKt = 1 - ((c(1:end-1)./c(2:end)).^(-sigmaC)/betta-1)./(r(2:end)-depr); %Euler
    tauKt = [tauK0 tauKt];
    plot(tauKt)

    PI2(ii)
    tauKt
    [swi dur]=min((tauKt-taumax/2).^2);
    duration2(ii)=dur;
    lamda_all2(ii)=lamda;
    Delta1_all2(ii)=Delta1;
    Delta2_all2(ii)=Delta2;
    term_all2(ii)=1+ALPHA*lamda^(1-sigmaC)+(Delta1+Delta2*lamda)*(1-sigmaC);
    kappa=lamda^(-sigmaC/sigmaL)*(phi(2)/phi(1))^(1/sigmaL);
    terml_all2(ii)=1+ALPHA*kappa^(1+sigmaL)+(Delta1+phi(2)/phi(1)*kappa*Delta2)*(1+sigmaL);
    cst_all2(ii)=cstst;
    gammast_all2(ii)=gamma(T);
    must_all2(ii)=mustst;

    %foc for labor (stored with suffix 2 so sweep 1's lamderl is kept)
    klam=lamda^(-sigmaC/sigmaL)*(phi(2)/phi(1))^(1/sigmaL);
    lamderl2(ii)=-lamda^(-sigmaC/sigmaL)*(phi(2)/phi(1))^((1+sigmaL)/sigmaL)*betta^T*(1+sigmaL)*lstst^(sigmaL)/(sum(betta.^[0:T-1].*c.^(-sigmaC).*c)+betta^T/(1-betta)*cstst^(-sigmaC)*cstst...
        -sigmaC/sigmaL*lamda^(-(sigmaC+sigmaL)/sigmaL)*(phi(2)/phi(1))^((1+sigmaL)/sigmaL)*(sum(betta.^[0:T-1].*(-omega*l.^(sigmaL).*l))+betta^T/(1-betta)*(-omega)*lstst^sigmaL*lstst));
    %foc for consumption (stored with suffix 2 so sweep 1's lamderc is kept)
    lamderc2(ii)=-lamda*betta^T*(1-sigmaC)*cstst^(-sigmaC)/(sum(betta.^[0:T-1].*c.^(-sigmaC).*c)+betta^T/(1-betta)*cstst^(-sigmaC)*cstst...
        -sigmaC/sigmaL*lamda^(-(sigmaC+sigmaL)/sigmaL)*(phi(2)/phi(1))^((1+sigmaL)/sigmaL)*(sum(betta.^[0:T-1].*(-omega*l.^(sigmaL).*l))+betta^T/(1-betta)*(-omega)*lstst^sigmaL*lstst));

    Vall2
    flip(Vsq)
    % Checkpoint the whole workspace after every weight.
    save('solN2PO.mat')

end

% Consumption-equivalent welfare gains (percent) relative to the status quo.
% eps solves (1-betta)*V = u(csq*(1+eps)) - omega*lsq^(1+sigmaL)/(1+sigmaL),
% with u the CRRA utility (log when sigmaC==1), matching how V was computed.
% Here csq = cistst_sq(g) and lsq = listst_sq(g) are the group's status-quo
% steady-state consumption and hours.
if sigmaC==1
    % log utility: csq*(1+eps) = exp((1-betta)*V + omega*lsq^(1+sigmaL)/(1+sigmaL))
    ceqgain = @(V,lsq,csq) (exp((1-betta)*V+omega*lsq^(1+sigmaL)/(1+sigmaL))/csq-1)*100;
else
    ceqgain = @(V,lsq,csq) ((((1-sigmaC)*((1-betta)*V+omega*lsq^(1+sigmaL)/(1+sigmaL)))+1).^(1/(1-sigmaC))/csq-1)*100;
end

% Sweep 1 keeps the original group order: column 1 = capitalists, column 2 = workers.
Vcap=Vall(:,1);
Vwor=Vall(:,2);
epscap=ceqgain(Vcap,listst_sq(1),cistst_sq(1));
epswor=ceqgain(Vwor,listst_sq(2),cistst_sq(2));

% Sweep 2 ran with the groups flipped, so its columns are swapped back here.
Vcap=Vall2(:,2);
Vwor=Vall2(:,1);
epscap2=ceqgain(Vcap,listst_sq(1),cistst_sq(1));
epswor2=ceqgain(Vwor,listst_sq(2),cistst_sq(2));

% Stitch both sweeps into one frontier (sweep 2 reversed, then sweep 1).
epscap_opt=[flip(epscap2)' epscap']
epswor_opt=[flip(epswor2)' epswor']

save('solN2PO.mat')

% Fresh session: overlay the Pareto-optimal welfare frontier (solid) with the
% frontier stored in solN2PO_taumax.mat (dash-dot). solN2PO.mat is loaded
% second, so any variable present in both files takes its value from it.
clc
clear all
close all
load('solN2PO_taumax.mat')
load('solN2PO.mat')

thickLine = {'LineWidth',2.5};
axisLine  = {'LineWidth',1.5};

plot(epswor_opt,epscap_opt,'-k',thickLine{:});
xlim([-20 15])
hold on
plot(epswor_taumax,epscap_taumax,'-.k',thickLine{:});

% Zero-gain reference lines for each group.
xline(0,':k',axisLine{:})
yline(0,':k',axisLine{:})

xlabel("workers' welfare increase (percent)",'FontSize',12)
ylabel("capitalists' welfare increase (percent)",'FontSize',12)
legend('PO','PO $\tau^k_t=\tilde{\tau}, \forall t$','Interpreter','latex')


